import argparse
import gc
import json
import logging
import os

import coremltools as ct
import coremltools.optimize.coreml as cto
import numpy as np

from python_coreml_stable_diffusion.torch2coreml import get_pipeline
from python_coreml_stable_diffusion.mixed_bit_compression_pre_analysis import (
    NBITS,
    PALETTIZE_MIN_SIZE as MIN_SIZE
)


# Module-level logger: basicConfig attaches a default stderr handler so the
# INFO-level progress messages below are visible when run as a script.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def main(args):
    """Apply a mixed-bit palettization recipe to a converted Core ML model.

    Loads the .mlpackage, looks up the selected recipe (a mapping of torch
    module names to target nbits) from the pre-analysis JSON, matches each
    Core ML weight tensor back to its originating torch parameter via a cheap
    tensor fingerprint, then palettizes each matched weight with the
    recipe-specified number of bits and saves the result to ``args.o``.

    Args:
        args: Parsed CLI namespace with ``mlpackage_path``,
            ``pre_analysis_json_path``, ``selected_recipe`` and ``o``.
            ``args.model_version`` is overwritten from the JSON.

    Raises:
        KeyError: If the selected recipe is missing from the JSON, or a
            recipe entry has no matching parameter in the torch model.
        AssertionError: If the recipe contains illegal nbits values or a
            Core ML weight cannot be matched to any hashed torch tensor.
    """
    # Load Core ML model (CPU only: no need for ANE/GPU just to rewrite weights)
    coreml_model = ct.models.MLModel(args.mlpackage_path, compute_units=ct.ComputeUnit.CPU_ONLY)
    logger.info(f"Loaded {args.mlpackage_path}")

    # Load palettization recipe
    with open(args.pre_analysis_json_path, 'r') as f:
        pre_analysis = json.load(f)

    if args.selected_recipe not in list(pre_analysis["recipes"]):
        raise KeyError(
            f"--selected-recipe ({args.selected_recipe}) not found in "
            f"--pre-analysis-json-path ({args.pre_analysis_json_path}). "
            f" Available recipes: {list(pre_analysis['recipes'])}"
        )

    recipe = pre_analysis["recipes"][args.selected_recipe]
    # 16 means "keep float16, do not palettize", hence it is legal on top of NBITS
    allowed_nbits = NBITS + [16]
    assert all(nbits in allowed_nbits for nbits in recipe.values()), \
        f"Some nbits values in the recipe are illegal. Allowed values: {allowed_nbits}"

    # Cheap tensor fingerprint (first element + element count) used to match
    # torch parameter names to the corresponding Core ML weight tensors.
    # NOTE(review): collisions are possible in principle; the nearest-match
    # tolerance check below guards against gross mismatches only.
    def get_tensor_hash(tensor):
        assert tensor.dtype == np.float16
        return tensor.ravel()[0] + np.prod(tensor.shape)

    args.model_version = pre_analysis["model_version"]
    pipe = get_pipeline(args)
    torch_model = pipe.unet

    # Map fingerprint -> target nbits for every module named in the recipe
    hashed_recipe = {}
    for torch_module_name, nbits in recipe.items():
        weight_name = torch_module_name + '.weight'
        tensor = next(
            (t.cpu().numpy().astype(np.float16)
             for name, t in torch_model.named_parameters()
             if name == weight_name),
            None,
        )
        if tensor is None:
            raise KeyError(
                f"Recipe entry '{torch_module_name}' has no matching parameter "
                f"'{weight_name}' in the torch model"
            )
        hashed_recipe[get_tensor_hash(tensor)] = nbits

    # The torch pipeline is only needed for fingerprinting; free it before
    # the memory-hungry palettization pass.
    del pipe
    gc.collect()

    op_name_configs = {}
    weight_metadata = cto.get_weights_metadata(coreml_model, weight_threshold=MIN_SIZE)
    hashes = np.array(list(hashed_recipe))
    min_size = int(MIN_SIZE)  # hoisted: loop-invariant conversion
    for name, metadata in weight_metadata.items():
        # Nearest fingerprint wins; the tolerance check rejects weights that
        # do not correspond to any hashed torch tensor.
        tensor_hash = get_tensor_hash(metadata.val)
        pdist = np.abs(hashes - tensor_hash)
        assert pdist.min() < 0.01, \
            f"No recipe tensor matches Core ML weight '{name}' (min distance {pdist.min()})"
        target_nbits = hashed_recipe[hashes[pdist.argmin()]]

        # 16 bits == keep as-is, skip palettization for this weight
        if target_nbits == 16:
            continue

        op_name_configs[name] = cto.OpPalettizerConfig(
            mode="kmeans",
            nbits=target_nbits,
            weight_threshold=min_size
        )

    # Consistent with the rest of the file: use the cto alias
    config = cto.OptimizationConfig(op_name_configs=op_name_configs)
    coreml_model = cto.palettize_weights(coreml_model, config)

    coreml_model.save(args.o)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        required=True,
        help="Output directory to save the custom palettized model"
    )
    parser.add_argument(
        "--mlpackage-path",
        required=True,
        help="Path to .mlpackage model to be palettized"
    )
    parser.add_argument(
        "--pre-analysis-json-path",
        required=True,
        type=str,
        help=("The JSON file generated by mixed_bit_compression_pre_analysis.py"
    ))
    parser.add_argument(
        "--selected-recipe",
        required=True,
        type=str,
        help=("The string key into --pre-analysis-json-path's baselines dict"
    ))
    parser.add_argument(
        "--custom-vae-version",
        type=str,
        default=None,
        help=
        ("Custom VAE checkpoint to override the pipeline's built-in VAE. "
            "If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. "
            "No precision override is applied when using a custom VAE."
    ))

    args = parser.parse_args()

    # Validate inputs before the (slow) model loading in main(); include the
    # offending path in the error so the user can tell which flag failed.
    if not os.path.exists(args.mlpackage_path):
        raise FileNotFoundError(f"--mlpackage-path does not exist: {args.mlpackage_path}")
    if not os.path.exists(args.pre_analysis_json_path):
        raise FileNotFoundError(f"--pre-analysis-json-path does not exist: {args.pre_analysis_json_path}")
    if not args.pre_analysis_json_path.endswith('.json'):
        # Fixed: previous message referenced a nonexistent --recipe-json-path flag
        raise ValueError("--pre-analysis-json-path should end with '.json'")

    main(args)
